package com.multiThread;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicStampedReference;

public class AtomicStampedReferenceTest {
    /**
     * corePoolSize具体要根据公式计算得出
     *
     */
    static int corePoolSize = 2;
    static int maximumPoolSize = 10;
    static ThreadPoolExecutor threadPool = new ThreadPoolExecutor(corePoolSize, maximumPoolSize, 10, TimeUnit.MINUTES, new ArrayBlockingQueue<Runnable>(10000));

    /**
     * 这里实现了多线程情况操作数据库，并将结果累加
     * 由于金额是BigDecimal类型的，所以我们用到了AtomicStampedReference实现
     * @param args
     */
    public static void main(String[] args)  {
        // 记报错的线程个数
        AtomicInteger errorCount = new AtomicInteger(0);

        // 要被操作的公共数据【AtomicStampedReference能避免引用类型对象的ABA问题】初始值为0，初始stamp也会0
        AtomicStampedReference<BigDecimal> sumCashPriceStamped = new AtomicStampedReference<>(BigDecimal.ZERO, 0);
        // 开始时间
        long s = System.currentTimeMillis();
        // 初始化数据[数字模拟数据库的id]
        List list = new ArrayList();
        for (int i = 0; i < 2001; i++) {
            list.add(i);
        }
        // 每个任务的大小
        int taskSize = 100;
        // 需要线程数据
        int threadNum = list.size() / taskSize;
        if(threadNum * taskSize == list.size()){
            threadNum -= 1;
            System.out.println("123456789");
        }
        CountDownLatch cdl = new CountDownLatch(threadNum);
        // 循环执行任务
        for (int i = 0; i <= threadNum; i++) {
            int fromIndex = i * taskSize;
            int toIndex = (i + 1) * taskSize;
            if (toIndex > list.size()) {
                toIndex = list.size();
            }
            // 以下必须重新new一个ArrayList来组装subList返回的结果。防止内存溢出[直接引用的每个对象里包含着父级大对象]
            List<Integer> idList = new ArrayList<>(list.subList(fromIndex, toIndex));
            threadPool.execute(new OrderProfitRateTask(cdl,errorCount,idList,sumCashPriceStamped));
        }

        try {
            // 等待所有线程执行完成（计算器值为0时）唤醒
            cdl.await();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        if(errorCount.get() > 0){
            System.out.println("报错了");
        }
        // 打印结果
        System.out.println(sumCashPriceStamped.getReference());
        System.out.println(sumCashPriceStamped.getStamp());
        // 关闭线程池
        threadPool.shutdown();
        long   e = System.currentTimeMillis();
        System.out.println("共耗时："+(s-e));
    }


    private static class OrderProfitRateTask implements Runnable {
        // 这些都是共享变量
        CountDownLatch countDountLatch;
        AtomicInteger errorCount;
        List idList;
        AtomicStampedReference<BigDecimal> sumCashPriceStamped;

        public OrderProfitRateTask(CountDownLatch countDountLatch,
                                   AtomicInteger errorCount,
                                   List<Integer> idList,
                                   AtomicStampedReference<BigDecimal> sumCashPriceStamped) {
            this.countDountLatch = countDountLatch;
            this.errorCount = errorCount;
            this.idList = idList;
            this.sumCashPriceStamped = sumCashPriceStamped;
        }

        @Override
        public void run() {
            System.out.println("线程名称：" + Thread.currentThread().getName());
            try {
                // 用id去数据库查数据伪代码
                // select sum(money)  from user where id in(idList);
                // 模拟耗时查询
                Thread.sleep(1000);
                // 这里直接假设查回来的sum值都是1
                BigDecimal sum = BigDecimal.ONE;


                // 循环CAS直到成功
                while (true) {
                    // 获取旧时间戳
                    int timestamp = sumCashPriceStamped.getStamp();
                    // 获取旧值
                    BigDecimal m = sumCashPriceStamped.getReference();
                    // CAS
                    if (sumCashPriceStamped.compareAndSet(m, m.add(sum), timestamp, timestamp + 1)) {
                        break;
                    }
                    // CAS失败继续循环
                }
            } catch (Exception e) {
                e.printStackTrace();
                errorCount.incrementAndGet();
            } finally {
                // 该线程执行完，计数器减1
                countDountLatch.countDown();
            }
        }
    }
}
